# tddft_benchmarking.py
# TD-DFT vertical excitation benchmarking
# Benchmarks: B3LYP, PBE0, CAM-B3LYP, ωB97X-D

from pyscf import gto, dft, tdscf
from pyscf.solvent import pcm
import numpy as np

print("="*80)
print("TD-DFT BENCHMARKING FOR CAMPHORQUINONE")
print("="*80)
print()

# Load optimized geometry from an XYZ file.
# XYZ layout: line 1 = number of atoms, line 2 = free-form comment,
# then one "Symbol x y z" line per atom.
xyz_path = 'camphorquinone_optimized.xyz'
print(f"Loading optimized geometry from {xyz_path}...")
# Only the free-form comment line can realistically hold non-ASCII bytes;
# decode leniently so it cannot abort the run (coordinates are ASCII).
with open(xyz_path, 'r', encoding='utf-8', errors='replace') as f:
    lines = f.readlines()

# Honour the atom count from the header so that additional frames (e.g. an
# optimization trajectory) or trailing text are not parsed as extra atoms.
# If the header is missing or not an integer, fall back to scanning every
# line after the comment line.
try:
    n_declared = int(lines[0].split()[0])
except (IndexError, ValueError):
    n_declared = None
xyz_lines = lines[2:] if n_declared is None else lines[2:2 + n_declared]

# Parse "Symbol x y z" records; lines with fewer than 4 fields are skipped.
atom_list = []
for line in xyz_lines:
    parts = line.split()
    if len(parts) >= 4:
        symbol = parts[0]
        x, y, z = map(float, parts[1:4])
        atom_list.append([symbol, (x, y, z)])

# Fail early with a clear message instead of letting PySCF choke later.
if not atom_list:
    raise ValueError(f"No atoms could be parsed from {xyz_path}")
if n_declared is not None and len(atom_list) != n_declared:
    raise ValueError(
        f"{xyz_path}: header declares {n_declared} atoms but "
        f"{len(atom_list)} atom lines were parsed (truncated file?)"
    )

print(f"✓ Loaded {len(atom_list)} atoms")
print()

# Assemble the PySCF molecule from the parsed XYZ atoms: a neutral
# (charge 0), closed-shell (spin = 2S = 0) system in the 6-31G(d) basis.
# All settings are handed to Mole.build() as keyword arguments, which
# assigns them to the Mole instance before building it.
mol = gto.Mole()
mol.build(
    atom=atom_list,
    basis='6-31g(d)',
    charge=0,
    spin=0,
    verbose=3,
)

# Functionals to benchmark, mapped display name -> PySCF xc code.
# In every case the xc code is spelled exactly like the display name.
# Note: PySCF uses 'WB97X-D' not 'ωB97X-D'
functionals = {
    name: name
    for name in ('B3LYP', 'PBE0', 'CAM-B3LYP', 'WB97X-D')
}

# Experimental absorption maximum used as the benchmark target.
exp_wavelength = 468.0  # nm
# E[eV] = hc / λ[nm], with hc ≈ 1239.84 eV·nm
exp_energy = 1239.84 / exp_wavelength

rule = "=" * 80
for text in (
    rule,
    "EXPERIMENTAL REFERENCE",
    rule,
    f"  λ_max = {exp_wavelength:.1f} nm",
    f"  E_exp = {exp_energy:.4f} eV",
    "",
):
    print(text)

# Results storage: functional name -> dict describing its first bright
# state, or None when no computed state exceeds BRIGHT_F_THRESHOLD.
results = {}

# Oscillator-strength cutoff separating "bright" from "dark" states.
# NOTE(review): if the 468 nm reference band is a weak transition, its
# computed f may fall below this cutoff and a different state would be
# compared to experiment. Confirm the threshold suits the band benchmarked.
BRIGHT_F_THRESHOLD = 0.01
HARTREE_TO_EV = 27.2114
HC_EV_NM = 1239.84  # hc in eV·nm, so λ[nm] = HC_EV_NM / E[eV]
n_states = 10

for func_name, xc_code in functionals.items():
    print("="*80)
    print(f"FUNCTIONAL: {func_name}")
    print("="*80)
    print()

    # Ground-state DFT in implicit solvent (PCM, dielectric constant 4.0)
    print(f"Running ground-state DFT with {func_name}...")
    mf = dft.RKS(mol)
    mf.xc = xc_code
    mf = pcm.PCM(mf)
    mf.with_solvent.eps = 4.0
    mf.verbose = 3
    E_gs = mf.kernel()
    # Excitations built on an unconverged reference are not meaningful.
    # Keep going so other functionals still run, but make the problem visible.
    scf_converged = bool(mf.converged)
    if not scf_converged:
        print(f"  WARNING: ground-state SCF did not converge for {func_name}")

    print(f"  Ground state energy: {E_gs:.10f} Hartree")
    print()

    # Linear-response TD-DFT on top of the solvated ground state
    print(f"Running TD-DFT for {n_states} excited states...")
    td = tdscf.TDDFT(mf)
    td.nstates = n_states
    td.verbose = 3
    td.kernel()
    # td.converged may be one flag per root; require all of them.
    td_converged = bool(np.all(td.converged))
    if not td_converged:
        print(f"  WARNING: not all TD-DFT roots converged for {func_name}")

    # Extract results
    exc_energies_au = td.e  # in Hartree
    exc_energies_ev = exc_energies_au * HARTREE_TO_EV
    osc_strengths = td.oscillator_strength()

    # Index of the first state (in the order returned by PySCF) whose
    # oscillator strength exceeds the threshold, or None if there is none.
    bright_idx = next(
        (i for i, f_osc in enumerate(osc_strengths) if f_osc > BRIGHT_F_THRESHOLD),
        None,
    )

    if bright_idx is not None:
        E_bright = exc_energies_ev[bright_idx]
        f_bright = osc_strengths[bright_idx]
        lambda_bright = HC_EV_NM / E_bright
        delta_lambda = lambda_bright - exp_wavelength

        results[func_name] = {
            'energy_eV': E_bright,
            'wavelength_nm': lambda_bright,
            'oscillator': f_bright,
            'delta_nm': delta_lambda,
            'state_index': bright_idx + 1
        }

        print()
        print(f"FIRST BRIGHT STATE (S{bright_idx + 1}):")
        print(f"  E = {E_bright:.4f} eV")
        print(f"  λ = {lambda_bright:.1f} nm")
        print(f"  f = {f_bright:.4f}")
        print(f"  Δλ = {delta_lambda:+.1f} nm vs. experiment")
    else:
        print(f"  WARNING: No bright state found (f > {BRIGHT_F_THRESHOLD})")
        results[func_name] = None

    print()

    # Write the detailed per-functional report. Explicit UTF-8 is required:
    # the report contains non-ASCII characters (λ, ε, ←) that a non-UTF-8
    # platform default encoding (e.g. cp1252) cannot encode.
    output_filename = f"tddft_{func_name.lower().replace('-', '')}.out"
    with open(output_filename, "w", encoding="utf-8") as out:
        out.write("="*80 + "\n")
        out.write("TD-DFT VERTICAL EXCITATION ENERGIES\n")
        out.write(f"Functional: {func_name}\n")
        out.write("Basis: 6-31G(d)\n")
        out.write("Solvent: PCM (ε = 4.0)\n")
        out.write("="*80 + "\n\n")

        out.write(f"Ground state energy: {E_gs:.10f} Hartree\n\n")
        if not scf_converged:
            out.write("WARNING: ground-state SCF did not converge\n\n")
        if not td_converged:
            out.write("WARNING: not all TD-DFT roots converged\n\n")

        out.write("EXCITED STATES:\n")
        out.write("-"*80 + "\n")
        out.write(f"{'State':>6} {'Energy (eV)':>12} {'λ (nm)':>10} {'f':>10} {'Character':>15}\n")
        out.write("-"*80 + "\n")

        for i, (e, f_osc) in enumerate(zip(exc_energies_ev, osc_strengths)):
            wavelength = HC_EV_NM / e
            character = "bright" if f_osc > BRIGHT_F_THRESHOLD else "dark"
            marker = " ←" if i == bright_idx else ""
            state_label = f"S{i + 1}"
            # Field widths mirror the header row above so the columns align.
            out.write(
                f"{state_label:>6} {e:12.4f} {wavelength:10.1f} "
                f"{f_osc:10.4f} {character:>15}{marker}\n"
            )

        out.write("-"*80 + "\n\n")

        if bright_idx is not None:
            out.write("BENCHMARK COMPARISON:\n")
            out.write(f"  Calculated λ_max: {lambda_bright:.1f} nm (S{bright_idx+1})\n")
            out.write(f"  Experimental λ_max: {exp_wavelength:.1f} nm\n")
            out.write(f"  Deviation: {delta_lambda:+.1f} nm\n")
            out.write(f"  Oscillator strength: {f_bright:.4f}\n")

    print(f"✓ Detailed output saved to {output_filename}")
    print()

# Summary comparison across all benchmarked functionals
print("="*80)
print("BENCHMARKING SUMMARY")
print("="*80)
print()
print(f"{'Functional':<15} {'E (eV)':>10} {'λ (nm)':>10} {'f':>8} {'Δλ (nm)':>10}")
print("-"*80)

for func_name in functionals:
    r = results.get(func_name)
    if r is None:
        # Keep every functional visible in the table instead of silently
        # dropping those that produced no bright state.
        print(f"{func_name:<15} {'(no bright state found)':>40}")
        continue
    print(f"{func_name:<15} {r['energy_eV']:>10.4f} {r['wavelength_nm']:>10.1f} "
          f"{r['oscillator']:>8.4f} {r['delta_nm']:>+10.1f}")

print("-"*80)
print(f"{'Experiment':<15} {exp_energy:>10.4f} {exp_wavelength:>10.1f}")
print()

# Identify the functional whose bright-state wavelength lies closest to
# experiment; ties resolve to the earliest functional in iteration order.
valid_results = {name: r for name, r in results.items() if r is not None}
if valid_results:
    best_func = min(valid_results,
                    key=lambda name: abs(valid_results[name]['delta_nm']))
    print(f"✓ BEST AGREEMENT: {best_func} (Δλ = {valid_results[best_func]['delta_nm']:+.1f} nm)")
else:
    print("⚠ No valid results obtained")

print()
print("="*80)
print("TD-DFT BENCHMARKING COMPLETE")
print("="*80)
print()
print("Output files generated:")
for func_name in functionals:
    filename = f"tddft_{func_name.lower().replace('-', '')}.out"
    print(f"  ✓ {filename}")
print()
